'''
Author: SlytherinGe
LastEditTime: 2021-09-21 16:14:35
'''
import torch
import numpy as np
import os
import sys
import xml.etree.ElementTree as ET
import cv2

sys.path.append('/media/gejunyao/Disk/Gejunyao/develop/toolbox-for-voc-dataset/') 

import matplotlib.pyplot as plt

from area_of_interest.aoi_box_generate import AOI_Generator
from image_process.os_cfar import os_cfar, bbox_guided_os_cfar

class CfarBasedAOIMaskGenerator(AOI_Generator):
    """Build binary area-of-interest (AOI) masks from ground-truth boxes via OS-CFAR.

    For every ground-truth box, ``bbox_guided_os_cfar`` is run starting at
    ``ALPHA_START`` and with ``alpha`` lowered by ``ALPHA_STEP`` after each try.
    This stops once the number of detected (non-zero) pixels reaches
    ``MIN_FILL_RATIO`` times the box area, or once ``alpha`` is no longer above
    ``ALPHA_MIN``. The per-box detections are merged into one mask. That mask is
    cleaned with a morphological opening followed by a closing.

    Annotations come from :class:`AOI_Generator`; images are read from
    ``img_root`` with OpenCV.
    """

    # Tuning knobs (values identical to the original hard-coded ones); override
    # in a subclass or on an instance to experiment.
    ALPHA_START = 2
    ALPHA_STEP = 0.1
    ALPHA_MIN = 0.5
    MIN_FILL_RATIO = 0.4
    # Third positional argument of bbox_guided_os_cfar. Its meaning is defined in
    # image_process.os_cfar (NOTE(review): presumably a window/guard size — confirm).
    CFAR_PARAM = 5
    # 2x2 structuring element for the final open/close clean-up.
    _MORPH_KERNEL = np.ones((2, 2), np.uint8)

    def __init__(self, anno_root, img_root):
        """Create the generator.

        Args:
            anno_root: directory forwarded to ``AOI_Generator`` for annotations.
            img_root: directory containing the images; its listing is sorted so
                that ``img_files[i]`` is the image paired with annotation ``i``.
        """
        super().__init__(anno_root, None)
        self.img_root = img_root
        # NOTE(review): pairing by sorted index assumes AOI_Generator orders its
        # annotations the same way and that every annotation has an image — confirm.
        self.img_files = sorted(os.listdir(img_root))

    def _get_mmdet_data_from_anno(self, index):
        """Assemble an mmdet-style ``results`` dict for sample ``index``.

        Returns:
            dict with
              * ``gt_bboxes``: int array of shape ``(N, 4)`` in
                ``(xmin, ymin, xmax, ymax)`` order; shape ``(0, 4)`` when the
                annotation lists no objects;
              * ``img_shape``: ``(height, width, depth)`` read from the annotation;
              * ``img``: the image loaded with ``cv2.imread``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        anno = self.get_anno_from_index(index)
        coord_keys = ('xmin', 'ymin', 'xmax', 'ymax')
        # NOTE(review): assumes anno['object'] is always an iterable of dicts
        # (also for single-object images) — depends on AOI_Generator's parser.
        bboxes = np.array(
            [[int(obj['bndbox'][key]) for key in coord_keys] for obj in anno['object']],
            dtype=np.int64,
        ).reshape(-1, 4)  # keep a consistent (N, 4) shape even when N == 0
        size = anno['size']
        shape = (int(size['height']), int(size['width']), int(size['depth']))
        img_path = os.path.join(self.img_root, self.img_files[index])
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cv2.imread could not read image: {}'.format(img_path))
        return dict(gt_bboxes=bboxes, img_shape=shape, img=img)

    def _adaptive_cfar_mask(self, img, bbox, area):
        """Run OS-CFAR for one box, relaxing alpha until enough of the box is filled.

        Returns the mask from the last attempt, or ``None`` if no attempt was made
        (only possible when ALPHA_START is configured at or below ALPHA_MIN).
        """
        alpha = self.ALPHA_START
        content = 0
        mask = None
        # Repeated float subtraction on purpose, to keep the original sequence of alphas.
        while content / area < self.MIN_FILL_RATIO and alpha > self.ALPHA_MIN:
            mask = bbox_guided_os_cfar(img, bbox, self.CFAR_PARAM, alpha)
            content = np.count_nonzero(mask)
            alpha -= self.ALPHA_STEP
        return mask

    def __call__(self, results):
        """Return the merged, morphologically cleaned AOI mask for one sample.

        Args:
            results (dict): requires ``'gt_bboxes'`` (numpy array, one row per box
                with ``(xmin, ymin, xmax, ymax)`` in the first four columns),
                ``'img_shape'`` (``(h, w, ...)``) and ``'img'`` (passed unchanged
                to ``bbox_guided_os_cfar``).

        Returns:
            np.ndarray: ``uint8`` array of shape ``(h, w)`` with values 0/1.
        """
        gt_bboxes = results['gt_bboxes'].astype(np.intp)
        height, width = results['img_shape'][:2]
        img = results['img']
        # Running sum of per-box masks. The original (h, w, N) canvas was only
        # ever summed over its last axis, so this gives the same union in O(h*w) memory.
        accum = np.zeros((height, width))
        for bbox in gt_bboxes:
            area = int((bbox[3] - bbox[1]) * (bbox[2] - bbox[0]))
            if area <= 0:
                # Degenerate/inverted box: the fill ratio is undefined (division
                # by zero), so the box contributes nothing to the mask.
                continue
            mask = self._adaptive_cfar_mask(img, bbox, area)
            if mask is not None:
                accum += mask
        res = (accum > 0).astype(np.uint8)
        # Opening removes isolated speckle pixels; closing fills small holes.
        res = cv2.morphologyEx(res, cv2.MORPH_OPEN, self._MORPH_KERNEL)
        res = cv2.morphologyEx(res, cv2.MORPH_CLOSE, self._MORPH_KERNEL)
        return res


if __name__ == "__main__":
    # Resumable batch run: generate the CFAR AOI mask for every image that has
    # no output in OUT_ROOT yet, and save it there as a 0/255 image with the
    # same file name.
    import time

    ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/CAS-OpenSARShip/ship_detection_online/VOC2012/Annotations'
    IMG_ROOT = '/media/gejunyao/Disk1/Datasets/CAS-OpenSARShip/ship_detection_online/VOC2012/JPEGImages'
    OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/results/cfar/os-d-cfar-cas-sarship'

    cfar_aoi_generator = CfarBasedAOIMaskGenerator(ANNO_ROOT, IMG_ROOT)

    os.makedirs(OUT_ROOT, exist_ok=True)
    # A set gives O(1) membership tests inside the loop.
    done_files = set(os.listdir(OUT_ROOT))
    total_size = cfar_aoi_generator.data_size()
    # NOTE(review): indexes img_files by annotation index, which assumes the
    # image listing is at least as long as (and aligned with) the annotations.
    for i in range(total_size):
        img_file = cfar_aoi_generator.img_files[i]
        print("{}/{}".format(i + 1, total_size))
        if img_file in done_files:
            continue
        print('missing {}'.format(img_file))
        data = cfar_aoi_generator._get_mmdet_data_from_anno(i)
        t1 = time.perf_counter()
        res = cfar_aoi_generator(data)
        save_path = os.path.join(OUT_ROOT, img_file)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(save_path, res * 255):
            raise IOError('cv2.imwrite failed for {}'.format(save_path))
        t2 = time.perf_counter()
        print('{}/{} [{}] time: {:3.3}s'.format(i + 1, total_size, img_file, t2 - t1))

